#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {
namespace {

/**
 * Keeps the most recent `num_to_collect` rows (slices along the first
 * dimension) of DATA in a buffer that is carried from one run to the next.
 *
 * Outputs:
 *   LAST_N - the buffer. It grows until it holds `num_to_collect` rows and is
 *            then overwritten circularly.
 *   NEXT   - 0-dimensional int32 cursor holding the row index in LAST_N that
 *            the next incoming row will be written to.
 *
 * The operator keeps no synchronization of its own around the buffer or the
 * cursor.
 */
template <class Context>
class LastNWindowCollectorOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  /**
   * Reads the required `num_to_collect` argument. The default of -1 makes a
   * missing argument fail the positivity check below.
   */
  LastNWindowCollectorOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        numToCollect_(
            OperatorBase::GetSingleArgument<int>("num_to_collect", -1)) {
    CAFFE_ENFORCE_GT(numToCollect_, 0);
  }

  /**
   * Merges the rows of DATA into the LAST_N buffer and advances NEXT.
   *
   * @return true on success; violated preconditions are reported through the
   *         CAFFE_ENFORCE_* checks.
   */
  bool RunOnDevice() override {
    auto* output = Output(LAST_N);
    const auto& input = Input(DATA);

    CAFFE_ENFORCE_GE(input.ndim(), 1);

    // A buffer that already holds elements must agree with the input on every
    // dimension except the first (the row count).
    const bool output_initialized = output->size() > 0;
    if (output_initialized) {
      CAFFE_ENFORCE_EQ(output->ndim(), input.ndim());
      // NOTE(review): assumes ndim() returns int -- confirm against Tensor.
      for (int i = 1; i < input.ndim(); ++i) {
        CAFFE_ENFORCE_EQ(output->dim(i), input.dim(i));
      }
    }

    auto dims = input.dims();
    auto num_entries = dims[0];

    // Reserve room for the full window of numToCollect_ rows up front.
    // NOTE(review): presumably this is what lets the later Resize() calls keep
    // the rows collected so far -- confirm against Tensor::Reserve.
    dims[0] = numToCollect_;
    output->Reserve(dims, &context_);

    if (num_entries == 0) {
      // Nothing to collect. If the buffer has never been filled, copy the
      // (empty) input so the output picks up both its shape and its meta.
      if (!output_initialized) {
        output->CopyFrom(input, &context_);
      }
      return true;
    }

    // Compare in 64 bits: narrowing the row count to int32_t before taking
    // the minimum would truncate very large batches. The result is bounded by
    // numToCollect_, so it always fits in int32_t.
    const auto num_to_copy = static_cast<int32_t>(
        std::min<int64_t>(num_entries, numToCollect_));
    auto output_batch_size = output_initialized ? output->dim(0) : 0;
    dims[0] = std::min<size_t>(numToCollect_, output_batch_size + num_to_copy);
    // Grow the buffer only while it is not yet full.
    if (output_batch_size < numToCollect_) {
      output->Resize(dims);
    }
    auto* output_data =
        static_cast<char*>(output->raw_mutable_data(input.meta()));

    auto* next = Output(NEXT);
    CAFFE_ENFORCE_EQ(0, next->ndim());
    auto* next_data = next->template mutable_data<int32_t>();
    if (!output_initialized) {
      // An empty buffer has no row to resume from, so start at row 0 instead
      // of trusting whatever value the cursor blob currently holds.
      *next_data = 0;
    }
    // The cursor is used as a row offset into the buffer below; a value
    // outside [0, rows) would make the copy write outside the buffer.
    CAFFE_ENFORCE_GE(*next_data, 0);
    CAFFE_ENFORCE_LT(*next_data, output->dim(0));

    // Number of elements / bytes in one row of the input.
    auto block_size = input.size_from_dim(1);
    auto block_bytesize = block_size * input.itemsize();
    const auto* input_data = static_cast<const char*>(input.raw_data());

    if (num_entries > numToCollect_) {
      // The input alone overflows the window: keep just its last N rows,
      // written from the start of the buffer, and rewind the cursor.
      context_.template CopyItems<Context, Context>(
          input.meta(),
          num_to_copy * block_size,
          input_data + (num_entries - numToCollect_) * block_bytesize,
          output_data);
      *next_data = 0;
      return true;
    }

    // Circular write: the first chunk fills rows from the cursor up to the end
    // of the window, the remainder (possibly empty) wraps around to row 0.
    auto start = *next_data;
    auto first_chunk_size =
        std::min<size_t>(num_to_copy + start, numToCollect_) - start;
    context_.template CopyItems<Context, Context>(
        input.meta(),
        first_chunk_size * block_size,
        input_data,
        output_data + start * block_bytesize);

    context_.template CopyItems<Context, Context>(
        input.meta(),
        (num_to_copy - first_chunk_size) * block_size,
        input_data + first_chunk_size * block_bytesize,
        output_data);

    *next_data = (start + num_to_copy) % numToCollect_;

    return true;
  }

 private:
  // Window size N: the maximum number of rows kept in LAST_N.
  const int32_t numToCollect_;

  INPUT_TAGS(LAST_N_IN, NEXT_IN, DATA);
  OUTPUT_TAGS(LAST_N, NEXT);
};

// Register the CPUContext instantiation of the op under the operator name
// LastNWindowCollector.
REGISTER_CPU_OPERATOR(LastNWindowCollector, LastNWindowCollectorOp<CPUContext>);

// Schema: three inputs (buffer, cursor, data) and two outputs (buffer,
// cursor); inputs 0 and 1 must be updated in place as outputs 0 and 1.
OPERATOR_SCHEMA(LastNWindowCollector)
    .NumInputs(3)
    .NumOutputs(2)
    .EnforceOneToOneInplace()
    .SetDoc(R"DOC(
Collect the last N rows from input data. The purpose is to keep track of data
across batches, so for example suppose the LastNWindowCollector is called
successively with the following input data

[1,2,3,4]
[5,6,7]
[8,9,10,11]

And the number of items is set to 6, then the output after the 3rd call
will contain the following elements:
[6,7,8,9,10,11]

No guarantee is made on the ordering of the collected elements in the output.
So a valid value for output could have been
[11,10,9,8,7,6]

Also, this method works for any order tensor, treating the first dimension as
input rows and keeping the last N rows seen as input. So for instance:

[[1,2],[2,3],[3,4],[4,5]]
[[5,6],[6,7],[7,8]]
[[8,9],[9,10],[10,11],[11,12]]

A possible output would be
[[6,7],[7,8],[8,9],[9,10],[10,11],[11,12]]

This is not thread safe.
)DOC")
    .Arg(
        "num_to_collect",
        "The number of most recent rows (N) to keep in the last-N buffer. "
        "Must be greater than 0.")
    .Input(
        0,
        "last-N buffer",
        "The buffer for last-N record. Should be initialized to an empty "
        "tensor")
    .Input(
        1,
        "next cursor",
        "The cursor pointing to the next position that should be replaced. "
        "Should be initialized to 0.")
    .Input(2, "DATA", "tensor to collect from")
    .Output(
        0,
        "last-N buffer",
        "Updated last-N buffer holding the most recently collected rows")
    .Output(1, "next cursor", "Updated input cursor");
// Mark LastNWindowCollector as an operator for which no gradient should be
// computed.
SHOULD_NOT_DO_GRADIENT(LastNWindowCollector);
}
}
